//
// Created by song on 16-11-18.
//

#ifndef C0COMPILER_ASTVISITOR_H
#define C0COMPILER_ASTVISITOR_H


#include <set>
#include "ASTNode.h"

class ASTVisitor{
protected:
    bool defaultAction;
    set<ASTNodeType> postVisitSet = {FunctionDeclar};
    virtual bool constDeclar(ASTNode& node){
        return defaultAction;
    }
    virtual bool varDeclar(ASTNode& node){
        return defaultAction;
    }
    virtual bool arrayDeclar(ASTNode& node){
        return defaultAction;
    }
    virtual bool funDeclar(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool param(ASTNode &node) {
        return defaultAction;
    }
    virtual bool ifStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool relationExpr(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool addSubExpr(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool multiDivExpr(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool functionInvoke(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool switchStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool returnStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool whileStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool arraySubscript(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool assignStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }
    virtual bool identifier(ASTNode& node){
        return defaultAction;
    }

    virtual bool condition(ASTNode &node, bool postVisit) {
        return defaultAction;
    }

    virtual bool caseStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }


    virtual bool defaultStmt(ASTNode &node, bool postVisit) {
        return defaultAction;
    }

    virtual bool sourceCode(ASTNode &node, bool postVisit) {
        return defaultAction;
    }

    virtual bool number(ASTNode &node) {
        return defaultAction;
    }

    virtual bool charConst(ASTNode &node) {
        return defaultAction;
    }

    virtual bool stringConst(ASTNode &node) {
        return defaultAction;
    }

    virtual bool block(ASTNode &node, bool postVisit) {
        return defaultAction;
    }

public:


    virtual bool visit(ASTNode &node, bool isPostVisit) {
        if(isPostVisit){
            if(postVisitSet.count(node.astType)==0){
                return defaultAction;
            }
        }
        switch(node.astType){
            case AssignStmt: return assignStmt(node, isPostVisit);
            case ArraySubscript: return arraySubscript(node, isPostVisit);
            case WhileStmt: return whileStmt(node, isPostVisit);
            case ReturnStmt: return returnStmt(node, isPostVisit);
            case SwitchStmt: return switchStmt(node, isPostVisit);
            case SourceCode: return sourceCode(node, isPostVisit);
            case Number: return number(node);
            case Identifier:return identifier(node);
            case CharConst: return charConst(node);
            case StringConst: return stringConst(node);
            case ConstDeclar: return constDeclar(node);
            case VariableDeclar: return varDeclar(node);
            case ArrayDeclar: return arrayDeclar(node);
            case FunctionDeclar:return funDeclar(node, isPostVisit);
            case Param: return param(node);
            case IfStmt: return ifStmt(node, isPostVisit);
            case CaseSegment: return caseStmt(node, isPostVisit);
            case DefaultSegment: return defaultStmt(node, isPostVisit);
            case RelationExpr: return relationExpr(node, isPostVisit);
            case AddSubExpr: return addSubExpr(node, isPostVisit);
            case MultiDivExpr: return multiDivExpr(node, isPostVisit);
            case FunctionInvoke:return functionInvoke(node, isPostVisit);
            case ConditionExpr:return condition(node, isPostVisit);
            case Block: return block(node, isPostVisit);
        }
        return defaultAction;
    }

    void startTraverse(ASTNode& root){
        walkAST(root);
    }

    void walkAST(ASTNode& root){
        bool go = visit(root, false);
        if(go){
            walkChildren(root);
        }
        visit(root, true);
    }

    void walkChildren(ASTNode& root){
        if(!root.isLeaf()) {
            vector<ASTNode *>::const_iterator iter;
            for (iter = root.children.begin();
                 iter != root.children.end(); iter++) {
                walkAST(**iter);
            }
        }
    }
};


#endif //C0COMPILER_ASTVISITOR_H
